package io.xmu.dataanalysis.model;

import java.util.HashSet;
import java.util.Iterator;

/**
 * Created by Jim Chen at XMU on 2017/5/7.
 */
public class RFM {
//    int count = 1500;
    int k = 5;
    public double[] W;
    public double[] R;
    public double[] F;
    public double[] P;
    public int[] ID;

    // R F P 三个因素的指标 [WR,WF,WM]=[0.221,0.341,0.439]
    public double[] getW() {
        return new double[] { 0.221, 0.341, 0.439 };
    }

    // 将各个数据规范化
    public void normalize(double data[]) {
        int n = data.length;
        double max = data[0];
        double min = data[0];
        for (int i = 0; i < n; i++) {
            if (max < data[i])
                max = data[i];
            if (min > data[i])
                min = data[i];
        }
        double p = max - min;
        for (int i = 0; i < n; i++) {
            data[i] = (data[i] - min) / p;
        }
    }

    public HashSet<Triple>[] sets;
    public double[] average;
    public double[] averageR;
    public double[] averageF;
    public double[] averageP;

    // 初始化分成k类
    @SuppressWarnings("unchecked")
    public void classify() {
        sets = new HashSet[k];
        for (int i = 0; i < k; i++) {
            sets[i] = new HashSet<>();
        }
        int n = R.length;
        int groupSize = n / k;
        int i = 0;
        int j = 0;
        for (; i < n; i++) {
            sets[j].add(new Triple(ID[i], R[i], F[i], P[i]));
            j = i / groupSize;
            if (j == k) {
                j = k - 1;
            }
        }
    }

    // 获得每个分组R F P 的平均
    public void getAverage() {
        average = new double[sets.length];
        for (int i = 0; i < average.length; i++) {
            int size = sets[i].size();
            double sum = 0;
            for (Triple t : sets[i]) {
                sum += (W[0] * t.r + W[1] * t.f + W[2] * t.p);
            }
            average[i] = sum / size;
        }
    }

    public boolean adjust() {
        boolean adjusted = false;
        int n = sets.length;
        for (int i = 0; i < n; i++) {
            Iterator<Triple> it = sets[i].iterator();
            while (it.hasNext()) {
                Triple t = it.next();
                double v = W[0] * t.r + W[1] * t.f + W[2] * t.p;
                double min = Math.abs(v - average[0]);
                int minIndex = 0;
                for (int j = 1; j < average.length; j++) {
                    double tmp = Math.abs(v - average[j]);
                    if (min > tmp) {
                        min = tmp;
                        minIndex = j;
                    }
                }
                if (minIndex != i) {
                    it.remove();
                    sets[minIndex].add(t);
                    adjusted = true;
                }

            }
        }
        return adjusted;
    }

    // 获得每个分组的R的平均值
    public void getAverageR() {
        averageR = new double[sets.length];
        for (int i = 0; i < averageR.length; i++) {
            int size = sets[i].size();
            double sum = 0;
            for (Triple t : sets[i]) {
                sum += t.r;
            }
            averageR[i] = sum / size;
        }
    }

    // 获得每个分组的F的平均值
    public void getAverageF() {
        averageF = new double[sets.length];
        for (int i = 0; i < averageF.length; i++) {
            int size = sets[i].size();
            double sum = 0;
            for (Triple t : sets[i]) {
                sum += t.f;
            }
            averageF[i] = sum / size;
        }
    }

    // 获得每个分组的P的平均值
    public void getAverageP() {
        averageP = new double[sets.length];
        for (int i = 0; i < averageP.length; i++) {
            int size = sets[i].size();
            double sum = 0;
            for (Triple t : sets[i]) {
                sum += t.p;
            }
            averageP[i] = sum / size;
        }
    }

    public void setP(double[] P){
        this.P = P;
    }

    public void setR(double[] R){
        this.R = R;
    }

    public void setF(double[] F){
        this.F = F;
    }

    public void setID(int[] ID){
        this.ID = ID;
    }

    public void progress() {
        W = getW();

        normalize(R);
        normalize(F);
        normalize(P);

        // 分成初始k类
        classify();
        getAverage();

        // 求每一类的平均值
        while (adjust()) {
            getAverage();
        }

        getAverageR();
        getAverageF();
        getAverageP();

        for(int i=0;i<k;i++) {
            for (int j = 0; j < k - i - 1; j++) {
                if (average[j] < average[j + 1]) {
                    HashSet<Triple> tmp = sets[j];
                    sets[j] = sets[j + 1];
                    sets[j + 1] = tmp;
                    double tmp1 = average[j];
                    average[j] = average[j + 1];
                    average[j + 1] = tmp1;
                }
            }
        }

    }

    public static class Triple {
       public int id;
       public double r;
       public double f;
       public double p;

        public Triple(int id, double r, double f, double p) {
            this.id = id;
            this.r = r;
            this.f = f;
            this.p = p;
        }

        @Override
        public int hashCode() {
            final int prime = 31;
            int result = 1;
            result = prime * result + id;
            return result;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj)
                return true;
            if (obj == null)
                return false;
            if (getClass() != obj.getClass())
                return false;
            Triple other = (Triple) obj;
            if (id != other.id)
                return false;
            return true;
        }

        @Override
        public String toString() {
            return "[id=" + id + ", r=" + r + ", f=" + f + ", p=" + p + "]";
        }
    }

}
